from pyspark.ml.clustering import KMeans, KMeansModel
from pyspark.ml.evaluation import ClusteringEvaluator
from utils.spark_util import col_validation, one_hot_encoding, feature_cols_labelling, \
    one_hot_encoding_p, col_name_check, get_model_info
import numpy as np


def predict(sc, dataframe, para_list, output):
    """Score ``dataframe`` with a previously trained KMeans model.

    Re-applies the training-time one-hot encoding (using the category lists
    recovered from the stored model's feature list), loads the persisted
    model, and appends a cluster-id column.

    :param sc: Spark context / session handle passed through to the helpers.
    :param dataframe: input Spark DataFrame to score.
    :param para_list: dict with at least "feature_col", "model_id", "source".
    :param output: result dict; "message" and
        ``output["result"]["output_params"]["output_cols"]`` are updated in place.
    :return: DataFrame with the original feature columns plus the new
        cluster-id column; intermediate encoded columns are dropped.
    """
    # --------------------prepare feature data------------------
    num_cols, cat_cols = col_validation(dataframe, para_list["feature_col"])

    model_save_path, feature_list_train = get_model_info(sc, para_list["model_id"])

    # Rebuild, for each categorical column of the current dataframe, the list
    # of categories seen at training time.  Entries in feature_list_train look
    # like "<feature>_OneHot/<category>"; categories of one feature appear
    # consecutively, and the encoded features are assumed to appear in the
    # same order as ``cat_cols`` (positional match).
    # NOTE(review): previously the code tested ``feature in cat_col_dict``
    # while keying the dict by ``cat_cols[i]`` — that only worked when the
    # trained feature name happened to equal the current column name, and
    # raised IndexError otherwise.  We now advance ``i`` only when the feature
    # name changes, which is name-agnostic.
    cat_col_dict = {}
    last_feature = None
    i = -1
    for col in feature_list_train:
        if "_OneHot/" in col:
            # maxsplit=1: a category value containing "_OneHot/" must not
            # produce extra split pieces.
            feature, category = col.split("_OneHot/", 1)
            if feature == last_feature:
                cat_col_dict[cat_cols[i]].append(category)
            else:
                # first category of the next categorical column
                i += 1
                last_feature = feature
                cat_col_dict[cat_cols[i]] = [category]

    if cat_cols:
        dataframe, message, encoded_cat_cols = one_hot_encoding_p(
            sc, dataframe, para_list["source"], cat_cols, cat_col_dict)
        if message:
            output["message"].append(message)
    else:
        encoded_cat_cols = []

    feature_list = num_cols + encoded_cat_cols

    df = feature_cols_labelling(dataframe, feature_list)

    # --------------------load the model-----------------
    model = KMeansModel.load(model_save_path)

    # --------------------make predictions----------------
    df_pred = model.transform(df)

    # Choose a cluster-id column name that does not collide with existing ones.
    new_col_name = col_name_check("_cluster_id_", df_pred.columns)
    df_pred = df_pred.withColumnRenamed("prediction", new_col_name)

    output_cols = list(para_list["feature_col"]) + [new_col_name]
    output["result"]["output_params"]["output_cols"] = output_cols

    # Drop intermediate columns: one-hot encodings and the assembled vector.
    for encoded_col in encoded_cat_cols:
        df_pred = df_pred.drop(encoded_col)
    df_pred = df_pred.drop("features")
    return df_pred


def train(sc, dataframe, para_list, record, output):
    """Train a KMeans model on the selected feature columns.

    Categorical features are one-hot encoded, all features are assembled into
    a vector column, and a KMeans model is fitted.  Training metrics
    (silhouette, cluster centers/sizes, iterations, cost) are written into
    ``record["other_info"]``.

    :param sc: Spark context / session handle passed through to the helpers.
    :param dataframe: input Spark DataFrame.
    :param para_list: dict with "feature_col", "source", "k", "max_iter",
        "model_type" (KMeans distance measure, e.g. "euclidean" or "cosine").
    :param record: dict updated in place with an "other_info" entry.
    :param output: result dict; encoding warnings are appended to "message".
    :return: the fitted KMeansModel.
    """
    # --------------------prepare data------------------
    num_cols, cat_cols = col_validation(dataframe, para_list["feature_col"])

    if cat_cols:
        dataframe, message, encoded_cat_cols = one_hot_encoding(
            sc, dataframe, para_list["source"], cat_cols)
        if message:
            output["message"].append(message)
    else:
        encoded_cat_cols = []

    feature_list = num_cols + encoded_cat_cols
    df = feature_cols_labelling(dataframe, feature_list)

    # ---------------------model fit---------------------
    kmeans = (KMeans()
              .setK(para_list["k"])
              .setMaxIter(para_list["max_iter"])
              .setDistanceMeasure(para_list["model_type"]))
    model = kmeans.fit(df)

    # ---------------------get metrics---------------
    predictions = model.transform(df)
    # TODO only has one cluster
    # Evaluate clustering by computing the Silhouette score.  The evaluator
    # defaults to squaredEuclidean; for a cosine-trained model that metric is
    # inconsistent with the clustering, so switch it to cosine as well.
    evaluator = ClusteringEvaluator()
    if para_list["model_type"] == "cosine":
        evaluator.setDistanceMeasure("cosine")
    silhouette = evaluator.evaluate(predictions)
    cluster_centers = model.clusterCenters()

    train_summary = model.summary

    # ---------------------record info----------------------
    record["other_info"] = {
        "feature_list": feature_list,
        "feature_col": para_list["feature_col"],
        "silhouette": silhouette,
        # centers are numpy arrays; convert to plain lists for serialization
        "cluster_centers": np.array(cluster_centers).tolist(),
        "cluster_sizes": train_summary.clusterSizes,
        "numIter": train_summary.numIter,
        "trainingCost": train_summary.trainingCost,
    }
    return model
